{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Neutral Class Design Pattern\n",
    "\n",
    "This notebook demonstrates, on a synthetic dataset,\n",
    "that creating a separate Neutral class can be helpful.\n",
    "It then carries the idea over to a real-world problem."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### On a synthetic dataset\n",
    "Patients with a history of jaundice are assumed to be at risk of liver damage and are prescribed ibuprofen, while patients with a history of stomach ulcers are prescribed acetaminophen. The remaining patients are arbitrarily assigned to either category. In the `prescription_with_neutral` column, those remaining patients are labeled `Neutral` instead.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>jaundice</th>\n",
       "      <th>prescription</th>\n",
       "      <th>prescription_with_neutral</th>\n",
       "      <th>ulcers</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>True</td>\n",
       "      <td>ibuprofen</td>\n",
       "      <td>ibuprofen</td>\n",
       "      <td>False</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>False</td>\n",
       "      <td>acetominophen</td>\n",
       "      <td>Neutral</td>\n",
       "      <td>False</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>False</td>\n",
       "      <td>acetominophen</td>\n",
       "      <td>Neutral</td>\n",
       "      <td>False</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>False</td>\n",
       "      <td>ibuprofen</td>\n",
       "      <td>Neutral</td>\n",
       "      <td>False</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>False</td>\n",
       "      <td>ibuprofen</td>\n",
       "      <td>Neutral</td>\n",
       "      <td>False</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>False</td>\n",
       "      <td>ibuprofen</td>\n",
       "      <td>Neutral</td>\n",
       "      <td>False</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>False</td>\n",
       "      <td>acetominophen</td>\n",
       "      <td>Neutral</td>\n",
       "      <td>False</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>False</td>\n",
       "      <td>acetominophen</td>\n",
       "      <td>Neutral</td>\n",
       "      <td>False</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>False</td>\n",
       "      <td>ibuprofen</td>\n",
       "      <td>Neutral</td>\n",
       "      <td>False</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>False</td>\n",
       "      <td>acetominophen</td>\n",
       "      <td>acetominophen</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   jaundice   prescription prescription_with_neutral  ulcers\n",
       "0      True      ibuprofen                 ibuprofen   False\n",
       "1     False  acetominophen                   Neutral   False\n",
       "2     False  acetominophen                   Neutral   False\n",
       "3     False      ibuprofen                   Neutral   False\n",
       "4     False      ibuprofen                   Neutral   False\n",
       "5     False      ibuprofen                   Neutral   False\n",
       "6     False  acetominophen                   Neutral   False\n",
       "7     False  acetominophen                   Neutral   False\n",
       "8     False      ibuprofen                   Neutral   False\n",
       "9     False  acetominophen             acetominophen    True"
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "def create_synthetic_dataset(N, shuffle):\n",
    "    # arbitrary assignment: half of the patients get each drug, in random order\n",
    "    prescription = np.full(N, fill_value='acetominophen', dtype='U20')\n",
    "    prescription[:N//2] = 'ibuprofen'\n",
    "    np.random.shuffle(prescription)\n",
    "    \n",
    "    # neutral class\n",
    "    p_neutral = np.full(N, fill_value='Neutral', dtype='U20')\n",
    "\n",
    "    # the first 10% of patients have a history of liver disease (jaundice)\n",
    "    jaundice = np.zeros(N, dtype=bool)\n",
    "    jaundice[0:N//10] = True\n",
    "    prescription[0:N//10] = 'ibuprofen'\n",
    "    p_neutral[0:N//10] = 'ibuprofen'\n",
    "\n",
    "    # the last 10% of patients have a history of stomach problems (ulcers)\n",
    "    ulcers = np.zeros(N, dtype=bool)\n",
    "    ulcers[(9*N)//10:] = True\n",
    "    prescription[(9*N)//10:] = 'acetominophen'\n",
    "    p_neutral[(9*N)//10:] = 'acetominophen'\n",
    "    \n",
    "    df = pd.DataFrame.from_dict({\n",
    "        'jaundice': jaundice,\n",
    "        'ulcers': ulcers,\n",
    "        'prescription': prescription,\n",
    "        'prescription_with_neutral': p_neutral\n",
    "    })\n",
    "    \n",
    "    if shuffle:\n",
    "        return df.sample(frac=1).reset_index(drop=True)\n",
    "    else:\n",
    "        return df\n",
    "\n",
    "create_synthetic_dataset(10, False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "label=prescription accuracy=0.555\n",
      "label=prescription_with_neutral accuracy=1.0\n"
     ]
    }
   ],
   "source": [
    "df = create_synthetic_dataset(1000, shuffle=True)\n",
    "\n",
    "from sklearn import linear_model\n",
    "for label in ['prescription', 'prescription_with_neutral']:\n",
    "    ntrain = 8*len(df)//10 # 80% of data for training\n",
    "    lm = linear_model.LogisticRegression()\n",
    "    lm = lm.fit(df.loc[:ntrain-1, ['jaundice', 'ulcers']], df[label][:ntrain])\n",
    "    acc = lm.score(df.loc[ntrain:, ['jaundice', 'ulcers']], df[label][ntrain:])\n",
    "    print('label={} accuracy={}'.format(label, acc))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### On the Natality data\n",
    "\n",
    "Let's do this on real data.\n",
    "A baby with an Apgar score of 10 is healthy and one with an Apgar score of <= 7 requires some medical attention.\n",
    "What about babies with scores of 8-9? They are neither perfectly healthy, nor do they need serious medical intervention.\n",
    "Let's compare how a 2-class model performs against a 3-class model that includes a Neutral class.\n",
    "\n",
    "First, without the Neutral class (scores of 9 and above are labeled Healthy; everything else is NeedsAttention):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%bigquery\n",
    "CREATE OR REPLACE MODEL mlpatterns.neutral_2classes\n",
    "OPTIONS(model_type='logistic_reg', input_label_cols=['health']) AS\n",
    "\n",
    "SELECT \n",
    "  IF(apgar_1min >= 9, 'Healthy', 'NeedsAttention') AS health,\n",
    "  plurality,\n",
    "  mother_age,\n",
    "  gestation_weeks,\n",
    "  ever_born\n",
    "FROM `bigquery-public-data.samples.natality`\n",
    "WHERE apgar_1min <= 10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>precision</th>\n",
       "      <th>recall</th>\n",
       "      <th>accuracy</th>\n",
       "      <th>f1_score</th>\n",
       "      <th>log_loss</th>\n",
       "      <th>roc_auc</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.565628</td>\n",
       "      <td>0.997893</td>\n",
       "      <td>0.565213</td>\n",
       "      <td>0.722007</td>\n",
       "      <td>0.690348</td>\n",
       "      <td>0.52722</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   precision    recall  accuracy  f1_score  log_loss  roc_auc\n",
       "0   0.565628  0.997893  0.565213  0.722007  0.690348  0.52722"
      ]
     },
     "execution_count": 41,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%%bigquery\n",
    "SELECT * FROM ML.EVALUATE(MODEL mlpatterns.neutral_2classes)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, with 3 classes: a score of 10 is Healthy, scores of 8-9 form the Neutral class, and lower scores are NeedsAttention."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%bigquery\n",
    "CREATE OR REPLACE MODEL mlpatterns.neutral_3classes\n",
    "OPTIONS(model_type='logistic_reg', input_label_cols=['health']) AS\n",
    "\n",
    "SELECT \n",
    "  IF(apgar_1min = 10, 'Healthy', IF(apgar_1min >= 8, 'Neutral', 'NeedsAttention')) AS health,\n",
    "  plurality,\n",
    "  mother_age,\n",
    "  gestation_weeks,\n",
    "  ever_born\n",
    "FROM `bigquery-public-data.samples.natality`\n",
    "WHERE apgar_1min <= 10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>precision</th>\n",
       "      <th>recall</th>\n",
       "      <th>accuracy</th>\n",
       "      <th>f1_score</th>\n",
       "      <th>log_loss</th>\n",
       "      <th>roc_auc</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.46499</td>\n",
       "      <td>0.333789</td>\n",
       "      <td>0.794872</td>\n",
       "      <td>0.296302</td>\n",
       "      <td>1.840975</td>\n",
       "      <td>0.553596</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   precision    recall  accuracy  f1_score  log_loss   roc_auc\n",
       "0    0.46499  0.333789  0.794872  0.296302  1.840975  0.553596"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%%bigquery\n",
    "SELECT * FROM ML.EVALUATE(MODEL mlpatterns.neutral_3classes)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Copyright 2020 Google Inc. Licensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
